import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import argparse

# Name of the CSV column that supplies the x-axis values.
X_COLUMN = 'data_size_kb'

# Output path used when -o/--output is not given.
DEFAULT_OUTPUT = 'hbm_bandwidth.png'

# Define the columns to plot (one line per column, drawn in this order)
columns_to_plot = ['vs1_u1', 'vs1_u2', 'vs1_u4',
                   'vs2_u1', 'vs2_u2', 'vs2_u4',
                   'vs4_u1', 'vs4_u2', 'vs4_u4']

# Define legend labels (keyed by CSV column name)
legend_labels = {
    'vs1_u1': 'VectorSize=1, Unroll=1',
    'vs1_u2': 'VectorSize=1, Unroll=2',
    'vs1_u4': 'VectorSize=1, Unroll=4',
    'vs2_u1': 'VectorSize=2, Unroll=1',
    'vs2_u2': 'VectorSize=2, Unroll=2',
    'vs2_u4': 'VectorSize=2, Unroll=4',
    'vs4_u1': 'VectorSize=4, Unroll=1',
    'vs4_u2': 'VectorSize=4, Unroll=2',
    'vs4_u4': 'VectorSize=4, Unroll=4',
}


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the plotting script.

    Returns:
        A parser accepting the positional ``csv_file`` and the optional
        ``-o/--output`` image path (defaults to ``DEFAULT_OUTPUT``).
    """
    parser = argparse.ArgumentParser(description='Plot HBM bandwidth from CSV data')
    parser.add_argument('csv_file', type=str, help='Input CSV file path')
    parser.add_argument('-o', '--output', type=str, default=DEFAULT_OUTPUT,
                        help='Output image path (default: %(default)s)')
    return parser


def missing_columns(data: pd.DataFrame) -> list:
    """Return the required column names that are absent from *data*.

    The required columns are ``X_COLUMN`` followed by every entry of
    ``columns_to_plot``; the result preserves that order and is empty when
    nothing is missing.
    """
    required = [X_COLUMN, *columns_to_plot]
    return [name for name in required if name not in data.columns]


def plot_bandwidth(data: pd.DataFrame, output_path: str) -> None:
    """Plot every column in ``columns_to_plot`` against ``X_COLUMN`` and save it.

    Args:
        data: Frame containing ``X_COLUMN`` and all of ``columns_to_plot``.
        output_path: Where the figure is written (300 dpi).
    """
    # Create a figure with appropriate size
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        # Plot each column
        for column in columns_to_plot:
            ax.plot(data[X_COLUMN], data[column], marker='o', linewidth=2,
                    label=legend_labels[column])

        # Set x-axis to log scale
        ax.set_xscale('log')

        # Add grid
        ax.grid(True, which="both", ls="-", alpha=0.2)

        # Add labels and title
        ax.set_xlabel('Data Size (KB)', fontsize=14)
        ax.set_ylabel('Bandwidth (GB/s)', fontsize=14)
        ax.set_title('H100 HBM Absolute Kernel Bandwidth', fontsize=16)

        # Add legend (anchored outside the axes, to the right)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=12)

        # Adjust layout
        fig.tight_layout()

        # Save the figure; bbox_inches='tight' keeps the outside legend in frame
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


def main(argv=None) -> None:
    """Parse arguments, read the CSV, validate its columns, and save the plot.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Exits with an argparse usage error if the CSV lacks a required column.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # Read the CSV file
    data = pd.read_csv(args.csv_file)

    # Fail early with a readable message instead of a KeyError mid-plot.
    absent = missing_columns(data)
    if absent:
        parser.error(
            f"{args.csv_file} is missing required column(s): {', '.join(absent)}"
        )

    plot_bandwidth(data, args.output)

    print(f"Plot saved as {args.output}")


if __name__ == "__main__":
    main()